package it.uniroma2.dtk.dt.partial;

import it.uniroma2.dtk.common.Spectrum;
import it.uniroma2.dtk.dt.DefaultAbstractDT;
import it.uniroma2.util.math.ArrayMath;
import it.uniroma2.util.math.MatrixUtils;
import it.uniroma2.util.tree.Tree;
import it.uniroma2.util.vector.VectorComposer;

import java.util.HashMap;
import java.util.List;

/**
 * Implementation of the DT interface for the Partial Tree Kernel presented in:
 * 		A. Moschitti. 
 * 		Efficient convolution kernels for dependency and constituent syntactic trees. 
 * 		In Proceedings of ECML'06, pages 318--329, 2006. 
 * <p>
 * For the sake of uniformity with the other DTs, parameters lambda and mu are swapped with respect to their use in the paper.
 * The weight of a tree fragment is lambda^(n/2)*mu^m where:
 * 	n is the number of nodes in the fragment;
 * 	m is the number of terminal nodes + the sum of |c|-1 for every production in the fragment.
 * For a production, |c| is the length of the children sequence, including gaps with respect to the original tree production.
 * This kernel includes single nodes as fragments as well. 
 * This class abstracts from the actual ideal composition function implementation.
 *
 * @author Lorenzo Dell'Arciprete
 */
public abstract class AbstractPartialDT extends DefaultAbstractDT {

	/** Number of precomputed powers of mu (exponents 0 .. MAX_STATIC_POWERS-1) kept in {@link #muPows}. */
	protected static final int MAX_STATIC_POWERS = 10;

	/** Parameter mu; it is 1 until {@link #setMu(double)} is called. */
	protected double mu = 1;

	/**
	 * Lookup table where muPows[i] holds mu^i.
	 * It is filled when the instance is created, so that it agrees with the default mu
	 * even if {@link #setMu(double)} is never called (e.g. when the 4-argument constructor is used).
	 */
	protected double[] muPows = fillPowers(mu, new double[MAX_STATIC_POWERS]);

	/**
	 * Builds the DT and sets both lambda and mu.
	 * All arguments but lambda and mu are forwarded unchanged to the superclass constructor.
	 * @param lambda - value passed to setLambda
	 * @param mu - value passed to {@link #setMu(double)}
	 * @throws Exception
	 */
	public AbstractPartialDT(int randomOffset, int vectorsSize, boolean usePos, boolean lexicalized, double lambda, double mu) throws Exception {
		this(randomOffset, vectorsSize, usePos, lexicalized);
		setLambda(lambda);
		setMu(mu);
	}

	/**
	 * Builds the DT leaving mu at its default value of 1.
	 * All arguments are forwarded unchanged to the superclass constructor.
	 * @throws Exception
	 */
	public AbstractPartialDT(int randomOffset, int vectorsSize, boolean usePos, boolean lexicalized) throws Exception {
		super(randomOffset, vectorsSize, usePos, lexicalized);
	}

	/**
	 * Fills the given table so that table[i] = base^i, computed by repeated multiplication.
	 * @param base - the number whose powers are computed
	 * @param table - the array to fill in place
	 * @return the same array received as input
	 */
	private static double[] fillPowers(double base, double[] table) {
		if (table.length > 0) {
			table[0] = 1;
		}
		for (int i = 1; i < table.length; i++) {
			table[i] = base * table[i-1];
		}
		return table;
	}

	/**
	 * Sets parameter mu and recomputes the table of its precomputed powers.
	 * @param mu - the new value for mu
	 */
	public void setMu(double mu) {
		this.mu = mu;
		fillPowers(mu, muPows);
	}

	/**
	 * @return the current value of parameter mu
	 */
	public double getMu() {
		return mu;
	}

	/**
	 * Returns mu^exp, reading it from the precomputed table when possible.
	 * @param exp - the exponent
	 * @return mu raised to the power of exp
	 */
	protected double getMuPow(int exp) {
		// Negative or large exponents are not in the table: fall back to Math.pow
		if (exp >= 0 && exp < MAX_STATIC_POWERS) {
			return muPows[exp];
		}
		return Math.pow(mu, exp);
	}

	/**
	 * Recursive computation of function s(n) for the root of the input tree.
	 * To save time and space, the computed s(n) is directly added to the final sum.
	 * s(n) sums all of the tree fragments rooted in n.
	 * @param node - the input tree or, equivalently, its root node
	 * @param sum - the object collecting the sum of s(n) for each node in the tree
	 * @return s(node)
	 * @throws Exception
	 */
	@Override
	protected double[] sRecursive(Tree node, Spectrum sum) throws Exception {
		double[] n = getLabelVector(node);
		double[] result;
		if (node.isTerminal() || (!lexicalized && node.isPreTerminal())) {
			//Only the single node fragment is available here
			result = ArrayMath.scalardot(mu, n);
		}
		else {
			//Single node fragment + the node composed with every fragment forest built on its children
			result = ArrayMath.sum(ArrayMath.scalardot(mu, n), op(n, d(node.getChildren(), sum)));
		}
		result = ArrayMath.scalardot(lambdaSq, result);
		//s(node) is accumulated in the spectrum as a side effect
		sum.setVector(VectorComposer.sum(sum.getVector(), result));
		return result;
	}

	/**
	 * Not supported for the Partial Tree Kernel: a warning is printed on standard output and null is returned.
	 * Use {@link #dtf(Tree, Tree)} instead.
	 * @param x - the tree fragment
	 * @return always null
	 */
	@Override
	public double[] dtf(Tree x) {
		System.out.println("WARNING: dtf(Tree x) cannot be correctly computed by itself for the " +
				"Partial Tree Kernel!\nUse dtf(Tree x, Tree original) instead!");
		return null;
	}

	/**
	 * This computes a DTF with respect to its original tree. 
	 * Parameters lambda and mu are also considered.
	 * @param x - The tree fragment
	 * @param original - The original tree
	 * @return - DTF(x)
	 * @throws Exception if the fragment is not found in the original tree, 
	 * 			or if it may refer to more than one of its subtrees
	 */
	public double[] dtf(Tree x, Tree original) throws Exception {
		Tree superTree = findSuperTree(x, original);
		if (superTree == null) {
			throw new Exception("Fragment not found in originary tree!");
		}
		if (x.isTerminal()) {
			//A leaf of the fragment is weighted by lambdaSq*mu, 
			//the same weight sRecursive gives to a single node
			return ArrayMath.scalardot(lambdaSq*mu, getLabelVector(x));
		}
		//The composition order is different from the one of the classic DTK, it is n#(c1#(c2#...#(cn-1#cn)...))
		//NOTE(review): no mu factor is applied here for the children positions/gaps, 
		//differently from dRecursive - confirm this is intended
		List<Tree> children = x.getChildren();
		double[] result = dtf(children.get(children.size()-1), superTree);
		for (int i = children.size()-2; i >= 0; i--) {
			result = op(dtf(children.get(i), superTree), result);
		}
		return ArrayMath.scalardot(lambdaSq, op(getLabelVector(x), result));
	}

	/**
	 * Looks for the subtree of the given tree that the fragment can be matched against (see isSuperTree).
	 * The tree itself is tested first; its children are searched, recursively, only if that test fails.
	 * @param fragment - the tree fragment to look for
	 * @param whole - the tree to search in
	 * @return the matching subtree, or null if there is none
	 * @throws Exception if matching subtrees are found under more than one child of the same node
	 */
	protected Tree findSuperTree(Tree fragment, Tree whole) throws Exception {
		if (isSuperTree(fragment, whole)) {
			return whole;
		}
		Tree superTree = null;
		for (Tree c : whole.getChildren()) {
			Tree match = findSuperTree(fragment, c);
			if (match == null) {
				continue;
			}
			if (superTree != null) {
				throw new Exception("Tree fragment may refer to multiple subtrees!");
			}
			superTree = match;
		}
		return superTree;
	}

	/**
	 * Checks whether the fragment can be matched against the given tree, starting from its root.
	 * The root labels must be equal; then the children of the fragment are matched, in order, 
	 * each against the first not yet examined child of the tree that matches it. 
	 * Children of the tree may be skipped.
	 * @param fragment - the tree fragment
	 * @param whole - the candidate tree
	 * @return true if the fragment matches the tree
	 */
	protected boolean isSuperTree(Tree fragment, Tree whole) {
		if (!fragment.getRootLabel().equals(whole.getRootLabel())) {
			return false;
		}
		if (fragment.isTerminal()) {
			return true;
		}
		List<Tree> candidates = whole.getChildren();
		//Index of the last examined child of whole: it never goes back
		int i = -1;
		for (Tree c : fragment.getChildren()) {
			do {
				i++;
				if (i >= candidates.size()) {
					return false;
				}
			} while (!isSuperTree(c, candidates.get(i)));
		}
		return true;
	}

	/**
	 * Computation of D(c). HashMap dValues is used for dynamic programming.
	 * D(c) sums all of the tree fragment forests rooted in any subset of nodes in c, to be attached to the parent node.
	 * @param c - the list of children nodes for the parent node. It must not be empty, 
	 * 			as its first element is always accessed.
	 * @param sum - the object collecting the sum of s(n) for each node in the tree.
	 * @return D(c)
	 * @throws Exception
	 */
	protected double[] d(List<Tree> c, Spectrum sum) throws Exception {
		HashMap<Integer, double[]> dValues = new HashMap<Integer, double[]>();
		double[] result = dRecursive(c, 0, dValues, sum);
		for (int i = 1; i < c.size(); i++) {
			result = ArrayMath.sum(result, dRecursive(c, i, dValues, sum));
		}
		return result;
	}

	/**
	 * Computation of d(c_i). Dynamic programming is used for efficiency reasons.
	 * d(c_i) sums all of the tree fragment forests rooted in c_i and any subset of nodes in c following c_i.
	 * @param c - the list of children nodes for the parent node.
	 * @param i - the current child index.
	 * @param dValues - the map used for dynamic programming.
	 * @param sum - the object collecting the sum of s(n) for each node in the tree.
	 * @return d(c_i)
	 * @throws Exception
	 */
	protected double[] dRecursive(List<Tree> c, int i, HashMap<Integer, double[]> dValues, Spectrum sum) throws Exception {
		//The stored value is returned before sRecursive is called, 
		//so s(c_i) is added to sum only once for each child
		if (dValues.containsKey(i)) {
			return dValues.get(i);
		}
		double[] sci = sRecursive(c.get(i), sum);
		double[] result;
		if (i < c.size()-1) {
			//d(c_k) for each following child c_k is weighted by mu^(k-i)
			double[] total = ArrayMath.scalardot(mu, dRecursive(c, i+1, dValues, sum));
			for (int k = i+2; k < c.size(); k++) {
				total = ArrayMath.sum(total, ArrayMath.scalardot(getMuPow(k-i), dRecursive(c, k, dValues, sum)));
			}
			result = ArrayMath.sum(sci, op(sci, total));
		}
		else {
			result = sci;
		}
		dValues.put(i, result);
		return result;
	}

}
